source("function_import.R")
trying URL 'https://cran.rstudio.com/src/contrib/plotly_4.9.2.1.tar.gz'
Content type 'application/x-gzip' length 3709741 bytes (3.5 MB)
==================================================
downloaded 3.5 MB
trying URL 'https://cran.rstudio.com/src/contrib/dotwhisker_0.5.0.tar.gz'
Content type 'application/x-gzip' length 935078 bytes (913 KB)
==================================================
downloaded 913 KB
trying URL 'https://cran.rstudio.com/src/contrib/broom_0.7.0.tar.gz'
Content type 'application/x-gzip' length 604195 bytes (590 KB)
==================================================
downloaded 590 KB
* installing *source* package ‘plotly’ ...
** package ‘plotly’ successfully unpacked and MD5 sums checked
** using staged installation
** R
** data
*** moving datasets to lazyload DB
** demo
** inst
** byte-compile and prepare package for lazy loading
** help
*** installing help indices
*** copying figures
** building package indices
** testing if installed package can be loaded from temporary location
** testing if installed package can be loaded from final location
** testing if installed package keeps a record of temporary installation path
* DONE (plotly)
* installing *source* package ‘broom’ ...
** package ‘broom’ successfully unpacked and MD5 sums checked
** using staged installation
** R
** inst
** byte-compile and prepare package for lazy loading
** help
*** installing help indices
*** copying figures
** building package indices
** installing vignettes
** testing if installed package can be loaded from temporary location
** testing if installed package can be loaded from final location
** testing if installed package keeps a record of temporary installation path
* DONE (broom)
* installing *source* package ‘dotwhisker’ ...
** package ‘dotwhisker’ successfully unpacked and MD5 sums checked
** using staged installation
** R
** inst
** byte-compile and prepare package for lazy loading
** help
*** installing help indices
** building package indices
** installing vignettes
** testing if installed package can be loaded from temporary location
** testing if installed package can be loaded from final location
** testing if installed package keeps a record of temporary installation path
* DONE (dotwhisker)
The downloaded source packages are in
‘/tmp/RtmpQggxRn/downloaded_packages’
Purpose. In this work, we explore the relation between identified measures of despair of interest (e.g., personality measures of self-consciousness, individual and composite item scores from the CES-D assessment) and descriptors of diseases of despair. We achieve this goal by modeling the outcomes on the included predictors and robustly assessing the importance of the included features via bootstrapping. We will use two well-known machine learning models, random forests and LASSO, both of which are frequently used to measure the relative importance of the predictors included in a model. Lastly, we will generate trained and tuned models using this reduced feature set, which can be used by others who wish to predict the identified outcomes.
Subject inclusion. For this investigation, we will omit the entirety of Wave 2. This is commonly done in analyses of AddHealth data due to the design of the original study. Otherwise, our dataset will include only subjects who have predictor and outcome data in all of the included waves.
Outcome variables. We will model problematic drinking assessed at Wave 5.
Predictor variables. The predictors for these models are hand-picked based on previous work, relevance, and subject matter expertise. The set of predictors and the set of outcomes are disjoint. Predictors from Waves 1-5 (excluding Wave 2; see above) are included and will be detailed in the following analysis.
seed <- 3895
set.seed(seed)
The predictors we will use are given by the variable predictor_list loaded from the 10-import-data.Rmd file. This initial set of predictors is based on the list of variables that describe anxiety, depression, and optimism.
## set outcome variable of interest
outcome <- 'h5to15'
wave_data <- load_waves(1:5)
full_dataset <- get_working_dataset_full(wave_data, join_type = 'full')
## Only study the subjects that we're interested in.
inner_aids <- get_inner(list(wave_data[[1]], wave_data[[3]], wave_data[[4]], wave_data[[5]]))
## get na_levels : dataset to recode all skip levels in variables
na_levels <- read_csv("na_levels.csv")
Parsed with column specification:
cols(
variable = col_character(),
wave = col_double(),
na_level = col_double(),
type = col_character()
)
## use the features and ids that you want to select out what you want
ds <- full_dataset %>%
filter(aid %in% inner_aids) %>%
remove_subjects_not_in_wave1() %>%
add_demographics() %>%
add_bio_despair() %>%
dplyr::select(aid, outcome, all_of(c(predictor_list, demographic_age_list, demographic_list))) %>%
dplyr::select(-c(h5waist,h5bmi,h5dbp,h5bpjcls,h5bpcls4,h5sbp)) %>%
recode_missing_levels(na_levels)
0 subjects removed from dataset.
[1] "Recoding Missing Factor Variables"
[1] "Factor variables being recoded : 63"
Unknown levels in `f`: 9
Unknown levels in `f`: 9
Unknown levels in `f`: 6
[1] "Recoding Missing Numeric Variables"
[1] "Numeric variables being recoded : 6"
This is the table with the counts for the different levels of the outcome variable. We can see that 97 is a legitimate-skip value. We can pass this into the binarize_var function and it will be coded as 0 (not responding to questions involving problematic drinking).
ds %>%
group_by(get(outcome)) %>%
summarise(total = n(), type = class(get(outcome)))
ds %>%
binarize_var(outcome, legit_skip = T, skip_var = 97) %>%
group_by(get(outcome)) %>%
summarise(total = n(), type = class(get(outcome)))
ds %>% explore_outcome(outcome, binary = F, legit_skip = T, skip_var = 97)
This shows the result of recoding: 0 and the legitimate skip (97) are recoded as the negative class, and the rest are regarded as the positive class. The results are consistent with the previous cell: 1372 97s + 3802 0s = 5174 total in the negative class. This fully accounts for the 9,349 subjects in the current dataframe.
ds <- ds %>%
binarize_var(outcome, legit_skip =T, skip_var = 97) %>%
drop_na(outcome)
Now that the outcome has been binarized and the 11 null outcomes dropped, there are 9,338 subjects in the dataset with 76 total predictors, plus aid and the outcome variable.
ds %>% explore_outcome(outcome)
The class distribution is slightly imbalanced, but not severely; this warrants using AUC as the optimization and selection metric.
Here, we comment on the general characteristics of the data based on the provided visualizations: missingness, any strange or unusual behavior (e.g., strong imbalances), and any correlations that stick out.
# Report on the characteristics of the subjects left out of the join
ds %>% explore_dropped()
# Visualize distributions of variables of interest
ds %>%
dplyr::select(-aid) %>%
graph_bar_discrete(df = .,
plot_title = "Distributions of Discrete Variables",
max_categories = 50,
num_rows = 3,
num_cols = 3,
x_axis_size = 12,
y_axis_size = 12,
title_size = 15)
ds %>%
graph_missing(only_missing = TRUE,
title = "Percent Missing",
box_line_size = .5,
label_size = .5,
x_axis_size = 12,
y_axis_size = 12,
title_size = 15)
ds %>%
#dplyr::select(1:20) %>%
pairwise_cramers_v() %>%
plot_cramer_v(x_axis_angle = 90,
plot_title = "Association among Categorical Variables",
interactive = TRUE)
We can see that some of the categorical variables have non-trivial association with others; entire blocks of predictors appear highly associated among themselves. (h4id5j, h4id5h) have an association of 0.6, (h3id5j, h4mh26) have 0.51, and associations among the h4mh* and h1fs* variables are near 0.5 or higher. This may result in importance being shared across correlated predictors, and we may need to remove some of these variables.
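As a reminder of what these values measure: Cramér's V for a pair of categorical variables is derived from the chi-squared statistic, V = sqrt(chi2 / (n * (min(r, c) - 1))). A minimal version (the project's pairwise_cramers_v presumably computes something similar for every pair) is:
cramers_v <- function(x, y) {
  # Cross-tabulate and compute the chi-squared statistic without correction.
  tbl <- table(x, y)
  chi2 <- suppressWarnings(stats::chisq.test(tbl, correct = FALSE))$statistic
  n <- sum(tbl)
  k <- min(nrow(tbl), ncol(tbl)) - 1  # normalizing term from table dimensions
  as.numeric(sqrt(chi2 / (n * k)))
}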
In this section, we split the data to ensure that our model is able to generalize to other datasets.
## split the data into relevant proportions desired
data_splits <- ds %>%
split_data(strat_var = all_of(outcome), ratios=c(0.7, 0.2, 0.1))
# assemble list
training_df <- data_splits$train %>% dplyr::select(-aid)
validation_df <- data_splits$valid %>% dplyr::select(-aid)
testing_df <- data_splits$test %>% dplyr::select(-aid)
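split_data is a project helper; for intuition, a bare-bones stratified 70/20/10 split could be sketched in base R as below. This is an illustration, not the helper's implementation, and it assumes each stratum is large enough to populate all three splits.
stratified_split <- function(df, strat_var, ratios = c(0.7, 0.2, 0.1)) {
  # Shuffle indices within each stratum, then cut each stratum by the ratios.
  parts <- lapply(split(seq_len(nrow(df)), df[[strat_var]]), function(i) {
    i <- sample(i)
    cuts <- floor(cumsum(ratios) * length(i))
    list(train = i[seq_len(cuts[1])],
         valid = i[(cuts[1] + 1):cuts[2]],
         test  = i[(cuts[2] + 1):length(i)])
  })
  sets <- c("train", "valid", "test")
  setNames(lapply(sets, function(s) df[unlist(lapply(parts, `[[`, s)), ]), sets)
}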
In this section, we explore the importance of features using bootstrapping.
The RF models are chosen based on a grid search over the following parameters:
rf_params <- list(max_depth = 50,
ntrees = 150,
mtries = c(-1, 20),
min_rows = c(5,10),
stopping_rounds = 4,
balance_classes = FALSE,
stopping_metric = 'auc',
categorical_encoding = 'one_hot_explicit')
# define number of bootstraps
n_boot <- 100
boots_rf <- model_feature_selection("RF", training_frame = training_df,
validation_frame = validation_df,
hyper_params = rf_params,
selection_metric = 'AUC',
outcome = outcome, n = n_boot, seed=seed)
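Conceptually, each bootstrap iteration resamples the training set with replacement, fits a model, and records that model's variable-importance ranking; the rankings are then aggregated across iterations. A hypothetical skeleton of that loop (model_feature_selection is the project's actual implementation; fit_fn and importance_fn are stand-ins):
bootstrap_importance <- function(df, n = 100, fit_fn, importance_fn, seed = 1) {
  set.seed(seed)
  lapply(seq_len(n), function(b) {
    boot_df <- df[sample(nrow(df), replace = TRUE), ]  # resample with replacement
    importance_fn(fit_fn(boot_df))  # importance ranking for this bootstrap fit
  })
}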
The following table displays the mean performance metrics for the bootstrapped models on the validation set, removing NA values before averaging.
mean_bs_rf_perf <- get_metric_set_from_perfs(boots_rf$perfs) %>%
dplyr::select(accuracy, mpce, sens, spec, ppv, npv, roc_auc, pr_auc,
tns, tps, fns, fps, no_n, no_p, err_rate, bal_accuracy, everything()) %>%
summarise_if(is.numeric, mean, na.rm=TRUE) %>%
mutate(model = 'bs_rf') %>%
dplyr::select(model, everything())
mean_bs_rf_perf
boot_rf_mdi <- boots_rf$mdi %>%
get_median_placement(use_base_var = TRUE) %>%
add_attribute_names('predictor', full_dataset) %>%
dplyr::select(predictor, att_name, overall_rank)
head(boot_rf_mdi, 20)
The following plot shows the variable importance ranks returned from each of the bootstrapped models.
# Needs to be fixed so that axes don't overlap each other and obscure understanding
plot_placement_boxplot(boots_rf$mdi)
Now, let’s look at the permutation importance:
boot_rf_perm_plt <- boots_rf$models %>%
get_aggregated_permute_imp(training_df, outcome=outcome)
Error: Can't subset columns that don't exist.
x Column `aid` doesn't exist.
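The error above suggests that get_aggregated_permute_imp expects the aid column, which was dropped when assembling training_df; passing the split that still carries aid (e.g., data_splits$train) is the likely fix. For intuition, permutation importance itself can be sketched in a few lines: shuffle one predictor at a time and measure the drop in a held-out metric. The helper below is a hypothetical illustration (score_fn stands in for scoring a fitted model on a data frame), not the project's implementation.
permute_importance <- function(df, outcome, score_fn, n_rep = 5) {
  base <- score_fn(df)  # metric (e.g., AUC) on the unshuffled data
  preds <- setdiff(names(df), outcome)
  sapply(preds, function(p) {
    drops <- replicate(n_rep, {
      shuffled <- df
      shuffled[[p]] <- sample(shuffled[[p]])  # break the predictor's signal
      base - score_fn(shuffled)               # larger drop = more important
    })
    mean(drops)
  })
}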
met <- 'roc_auc'
boot_rf_perm <- boot_rf_perm_plt %>%
get_permute_placement(metric_oi=met) %>%
add_attribute_names('predictor', full_dataset) %>%
dplyr::select(predictor, everything())
head(boot_rf_perm, 20)
Now, let’s look at these models together:
cbind(boot_rf_mdi[1:20,], dplyr::select(boot_rf_perm[1:20,], -met))
plot_permute_var_imp(boot_rf_perm, metric = roc_auc)
In this step, we model the relation between the outcome and the predictors using a regression with L1 regularization (LASSO). The penalty drives the coefficients of unimportant and redundant features toward zero.
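The call that produces boots_lasso is not shown in this section; it presumably mirrors the random forest bootstrap call above, along the lines of the following sketch (the alpha = 1 hyperparameter is an assumption carried over from the final-model section below):
lasso_params <- list(alpha = c(1))  # alpha = 1 gives the L1 (LASSO) penalty
boots_lasso <- model_feature_selection("LASSO", training_frame = training_df,
                                       validation_frame = validation_df,
                                       hyper_params = lasso_params,
                                       selection_metric = 'AUC',
                                       outcome = outcome, n = n_boot, seed = seed)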
mean_bs_lasso_perf <- get_metric_set_from_perfs(boots_lasso$perfs) %>%
dplyr::select(accuracy, mpce, sens, spec, ppv, npv, roc_auc, pr_auc,
tns, tps, fns, fps, no_n, no_p, err_rate, bal_accuracy, everything()) %>%
summarise_if(is.numeric, mean, na.rm=TRUE) %>%
mutate(model='bs_lasso') %>%
dplyr::select(model, everything())
mean_bs_lasso_perf
boot_lasso_mdi <- boots_lasso$mdi %>%
get_median_placement(use_base_var = TRUE) %>%
add_attribute_names('predictor', full_dataset) %>%
dplyr::select(predictor, att_name, overall_rank)
`summarise()` ungrouping output (override with `.groups` argument)
head(boot_lasso_mdi, 20)
plot_placement_boxplot(boots_lasso$mdi)
boot_lasso_perm_plt <- boots_lasso$models %>%
get_aggregated_permute_imp(training_df, outcome=outcome)
boot_lasso_perm <- boot_lasso_perm_plt %>%
get_permute_placement(metric_oi=met) %>% #set in random forest section
add_attribute_names('predictor', full_dataset) %>%
dplyr::select(predictor, everything())
head(boot_lasso_perm, 20)
plot_permute_var_imp(boot_lasso_perm, metric = roc_auc)
Now, we compare the feature importances generated by the two different approaches. The traditional method of evaluating feature importance for regression models is through analysis of the coefficients.
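As an illustrative aside (using glmnet directly rather than the project's wrappers), those coefficients can be inspected as follows; the exact zeros are the features the L1 penalty has dropped:
library(glmnet)
cc <- stats::na.omit(training_df)  # glmnet does not accept NAs; complete cases only
x <- model.matrix(~ . - 1, data = cc[, setdiff(names(cc), outcome)])
y <- cc[[outcome]]
cv_fit <- cv.glmnet(x, y, family = "binomial", alpha = 1)  # alpha = 1 => LASSO
coef(cv_fit, s = "lambda.min")  # sparse coefficient vector at the best lambda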
cbind(boot_lasso_mdi[1:20,], select(boot_lasso_perm[1:20,], -met))
The following table compares the mean performance of bootstrapped random forests to the mean performance of bootstrapped LASSO methods.
bs_comp_perfs <- rbind(mean_bs_rf_perf, mean_bs_lasso_perf)
bs_comp_perfs
Here, we look at the aggregated results of the bootstrapped predictors and compare the models generated to each other.
joined_results <- boot_rf_perm %>%
dplyr::select(-met) %>%
full_join(dplyr::select(boot_lasso_perm, -met), by=c("predictor", "att_name"), suffix=c('.rf', '.lasso')) %>%
mutate(mean_rank = rowMeans(dplyr::select(., overall_rank.rf, overall_rank.lasso), na.rm=TRUE)) %>%
arrange(mean_rank)
head(joined_results, 20)
The following visualization provides intuition about the differences in rankings between model types. Variables are ordered by overall mean importance, and for a given variable, the difference in rank between models is shown.
# Comparison of top_n features
joined_results %>%
compare_feature_select(interactive = TRUE,
top_n = 100,
opacity = 0.50,
plot_title = "Permutation Importance of Predictors by Model")
In this step, we build the final random forest model. We search over slightly more hyperparameter values in order to find the best model, keeping in mind the number of combinations required to evaluate the grid.
# Spans of hyperparameters for the random forest
rf_params <- list(max_depth = 50,
ntrees = 150,
mtries = seq(-1, 30, by=5),
min_rows = seq(5, 60, by=5),
balance_classes = c(TRUE, FALSE),
stopping_metric = 'AUCPR',
categorical_encoding = 'one_hot_explicit')
# rf_params <- list(max_depth = seq(20, 50, 20),
# balance_classes = TRUE,
# categorical_encoding= 'one_hot_explicit')
# Function parameters
final_model_rf <- rf_model(outcome,
training_frame = training_df,
validation_frame = validation_df,
nfolds = 5,
hyper_params = rf_params, model_seed=seed)
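rf_model wraps the grid search; the parameter names (mtries, min_rows, balance_classes, categorical_encoding) suggest h2o's distributed random forest under the hood. Under that assumption, a bare-bones version of such a grid might look like the following sketch (not the project's wrapper; it also assumes the outcome column is already a factor):
library(h2o)
h2o.init()
train_h2o <- as.h2o(training_df)
valid_h2o <- as.h2o(validation_df)
grid <- h2o.grid(algorithm = "randomForest",
                 x = setdiff(names(training_df), outcome),
                 y = outcome,
                 training_frame = train_h2o,
                 validation_frame = valid_h2o,
                 hyper_params = list(mtries = seq(-1, 30, by = 5),
                                     min_rows = seq(5, 60, by = 5)))
# Rank the grid's models by validation AUC and keep the best one.
sorted <- h2o.getGrid(grid@grid_id, sort_by = "auc", decreasing = TRUE)
best_rf <- h2o.getModel(sorted@model_ids[[1]])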
The final random forest performance metrics are shown below:
# show model final performance
print(final_model_rf[[2]])
final_rf_perm_plt <- c(final_model_rf[[1]]) %>%
get_aggregated_permute_imp(training_df, outcome=outcome)
final_rf_perm <- final_rf_perm_plt %>%
get_permute_placement(metric_oi=met) %>%
add_attribute_names('predictor', full_dataset) %>%
dplyr::select(predictor, everything())
head(final_rf_perm, 20)
plot_permute_var_imp(final_rf_perm, metric = roc_auc)
This section investigates the differences between the bootstrap results and the feature rankings generated by the final random forest model. The following table shows the overall differences in rank.
rf_joined_results <- final_rf_perm %>%
dplyr::select(-met) %>%
full_join(dplyr::select(boot_rf_perm, -met), by=c("predictor", "att_name"), suffix=c('.final', '.bootstrap')) %>%
mutate(mean_rank = (overall_rank.final + overall_rank.bootstrap)/2) %>%
arrange(mean_rank)
head(rf_joined_results, 20)
The following plot visualizes the differences between the final model rankings and the bootstrap rankings.
# Comparison of top_n features
rf_joined_results %>%
compare_feature_select(sel_cols = c("overall_rank.final", "overall_rank.bootstrap"),
interactive = TRUE,
top_n = 100,
opacity = 0.50,
plot_title = "Permutation Importance of Predictors: Final vs. Bootstrap")
Now, we create the final LASSO model. There is no substantial difference between this method and the bootstrap method, other than the data on which the model is built.
# Function parameters
lasso_params <- list(alpha = c(1))
final_model_lasso <- lasso_model(training_frame = training_df,
validation_frame = validation_df,
outcome = outcome,
nfolds = 5,
hyper_params = lasso_params)
The final LASSO performance metrics are shown below:
# show model final performance
print(final_model_lasso[[2]])
final_lasso_perm_plt <- c(final_model_lasso[[1]]) %>%
get_aggregated_permute_imp(training_df, outcome=outcome)
final_lasso_perm <- final_lasso_perm_plt %>%
get_permute_placement(metric_oi=met) %>%
add_attribute_names('predictor', full_dataset) %>%
dplyr::select(predictor, everything())
head(final_lasso_perm, 20)
plot_permute_var_imp(final_lasso_perm, metric = roc_auc)
This section investigates the differences between the bootstrap results and the feature rankings generated by the final LASSO model. The following table shows the overall differences in rank.
lasso_joined_results <- final_lasso_perm %>%
dplyr::select(-met) %>%
full_join(dplyr::select(boot_lasso_perm, -met), by=c("predictor", "att_name"), suffix=c('.final', '.bootstrap')) %>%
mutate(mean_rank = (overall_rank.final + overall_rank.bootstrap)/2) %>%
arrange(mean_rank)
head(lasso_joined_results, 20)
The following plot visualizes the differences between the final model rankings and the bootstrap rankings.
# Comparison of top_n features
lasso_joined_results %>%
compare_feature_select(sel_cols = c("overall_rank.final", "overall_rank.bootstrap"),
interactive = TRUE,
top_n = 100,
opacity = 0.50,
plot_title = "Permutation Importance of Predictors: Final vs. Bootstrap")
Here, we compare the features generated by the permutation importance between the two final models.
rf_lasso_final_joined_results <- final_rf_perm %>%
dplyr::select(-met) %>%
full_join(dplyr::select(final_lasso_perm, -met), by=c("predictor", "att_name"), suffix=c('.rf', '.lasso')) %>%
mutate(mean_rank = (overall_rank.rf+overall_rank.lasso)/2) %>%
arrange(mean_rank)
head(rf_lasso_final_joined_results, 20)
The following visualization provides intuition about the differences in rankings between the final model types. Variables are ordered by overall mean importance, and for a given variable, the difference in rank is shown.
# Comparison of top_n features
rf_lasso_final_joined_results %>%
compare_feature_select(sel_cols = c("overall_rank.rf", "overall_rank.lasso"),
interactive = TRUE,
top_n = 100,
opacity = 0.50,
plot_title = "Permutation Importance of Predictors: Random Forest vs Lasso")
With the final models generated, we’re now able to compare their performance metrics.
# Comparison of performance metrics
valid_perf <- get_metric_set_from_perfs(perf_list = list(final_model_rf[[2]], final_model_lasso[[2]])) %>%
mutate(model = c('rf', 'lasso'))
testing_perf <- get_metric_set_from_models(testing_df, list(final_model_rf[[1]], final_model_lasso[[1]]), out=outcome) %>%
mutate(model = c('rf', 'lasso'))
`summarise()` ungrouping output (override with `.groups` argument)
`summarise()` ungrouping output (override with `.groups` argument)
Validation and selection. The following table compares the models on the validation set. We select our final model as the best-performing model according to the chosen metric.
print(valid_perf)
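For example, with dplyr the best row can be pulled directly (assuming roc_auc is the selection metric):
# Select the best-performing model by validation ROC AUC.
valid_perf %>% dplyr::slice_max(roc_auc, n = 1)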
Testing performance. The following shows the performance of both models on the test set. Note that although we don't use the test set to select the final model, we can still see how our selected method would have performed.
print(testing_perf)
The following plots compare the performance of the models on the validation and test sets. Again, we don't choose the model based on the test set, but curiosity dictates that we view this performance.
# Show plots side by side
metrics_of_interest <- c('model', 'accuracy', 'bal_accuracy', 'mpce', 'sens', 'spec', 'ppv', 'npv', 'pr_auc', 'roc_auc')
valid_plt <- plot_metric_set(dplyr::select(valid_perf, all_of(metrics_of_interest)), plot_title = "Model comparison for validation set")
test_plt <- plot_metric_set(dplyr::select(testing_perf, all_of(metrics_of_interest)), plot_title = "Model comparison for testing set")
gridExtra::grid.arrange(gridExtra::arrangeGrob(valid_plt, test_plt, ncol=2, nrow=1))